{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T09:31:55.502155Z",
     "iopub.status.busy": "2022-03-16T09:31:55.501641Z",
     "iopub.status.idle": "2022-03-16T09:31:55.627189Z",
     "shell.execute_reply": "2022-03-16T09:31:55.626017Z",
     "shell.execute_reply.started": "2022-03-16T09:31:55.502023Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-rw-rw-r-- 1 lyz lyz  91M 3月   1 15:25 ./input/corpus.tsv\n",
      "-rw-rw-r-- 1 lyz lyz 1.5K 3月   9 10:35 ./input/data_check.py\n",
      "-rw-rw-r-- 1 lyz lyz  26K 3月  10 14:12 ./input/dev.query.txt\n",
      "-rw-rw-r-- 1 lyz lyz 1.3M 3月  10 14:13 ./input/qrels.train.tsv\n",
      "-rw-rw-r-- 1 lyz lyz 2.3M 3月  10 14:13 ./input/train.query.txt\n"
     ]
    }
   ],
   "source": [
    "!ls ./input/* -lh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 数据集读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:27:34.164408Z",
     "iopub.status.busy": "2022-03-16T10:27:34.163847Z",
     "iopub.status.idle": "2022-03-16T10:27:34.170036Z",
     "shell.execute_reply": "2022-03-16T10:27:34.169188Z",
     "shell.execute_reply.started": "2022-03-16T10:27:34.164350Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:26:41.645534Z",
     "iopub.status.busy": "2022-03-16T10:26:41.645124Z",
     "iopub.status.idle": "2022-03-16T10:26:43.097631Z",
     "shell.execute_reply": "2022-03-16T10:26:43.097141Z",
     "shell.execute_reply.started": "2022-03-16T10:26:41.645487Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "corpus_data = pd.read_csv( \"./input/corpus.tsv\", sep=\"\\t\", names=[\"doc\", \"title\"])\n",
    "dev_data = pd.read_csv(\"./input/dev.query.txt\", sep=\"\\t\", names=[\"query\", \"title\"])\n",
    "train_data = pd.read_csv(\"./input/train.query.txt\", sep=\"\\t\", names=[\"query\", \"title\"])\n",
    "qrels = pd.read_csv(\"./input/qrels.train.tsv\", sep=\"\\t\", names=[\"query\", \"doc\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:26:44.529550Z",
     "iopub.status.busy": "2022-03-16T10:26:44.529012Z",
     "iopub.status.idle": "2022-03-16T10:26:44.568556Z",
     "shell.execute_reply": "2022-03-16T10:26:44.568044Z",
     "shell.execute_reply.started": "2022-03-16T10:26:44.529501Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "corpus_data = corpus_data.set_index(\"doc\")\n",
    "dev_data = dev_data.set_index(\"query\")\n",
    "train_data = train_data.set_index(\"query\")\n",
    "qrels = qrels.set_index(\"query\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:26:45.564253Z",
     "iopub.status.busy": "2022-03-16T10:26:45.563840Z",
     "iopub.status.idle": "2022-03-16T10:26:45.574448Z",
     "shell.execute_reply": "2022-03-16T10:26:45.573772Z",
     "shell.execute_reply.started": "2022-03-16T10:26:45.564207Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>doc</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>query</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>679139</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>35343</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>781652</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>557516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>588014</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          doc\n",
       "query        \n",
       "1      679139\n",
       "2       35343\n",
       "3      781652\n",
       "4      557516\n",
       "5      588014"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qrels.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:26:46.860690Z",
     "iopub.status.busy": "2022-03-16T10:26:46.860281Z",
     "iopub.status.idle": "2022-03-16T10:26:46.900672Z",
     "shell.execute_reply": "2022-03-16T10:26:46.900195Z",
     "shell.execute_reply.started": "2022-03-16T10:26:46.860644Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "美赞臣亲舒一段 \t 领券满减】美赞臣安婴儿A+亲舒 婴儿奶粉1段850克 0-12个月宝宝\n",
      "慱朗手动料理机 \t Braun/博朗 MQ3035/3000/5025料理棒手持小型婴儿辅食家用搅拌机\n",
      "電力貓 \t 小米WiFi电力猫无线路由器套装一对300M穿墙宝家用信号增强扩展器\n",
      "掏夹缝工具 \t 电梯地坎清洁工具除灰尘神器轿厢门槽缝隙掏勺维保打扫奥的斯三菱\n",
      "飞推vip \t 飞逗推拍 店主邀请码 去水印 创意视频一键制作视频\n",
      "多功能托地把 \t 免手洗拖把家用一拖净刮刮乐干湿两用懒人拖平板墩布托帕拖地神器\n",
      "充气浮力袖 \t 学游泳神器装备充气腰背漂水袖浮臂三角浮力儿童游泳辅助工具大人\n",
      "盒马花胶鸡汤锅 \t 盒马鲜生工坊代购 花胶奶冻150g 入口Q弹 奶味浓郁 香甜丝滑\n",
      "塞塞乐 \t 婴儿童玩具6个月以上8宝宝益智早教0一1岁男孩女孩六9月十7新生礼\n",
      "广汽传祺gs5挡风遮雨条子 \t 2021款广汽传祺GS5晴雨挡遮雨板传奇GS5配件车窗雨眉防雨条挡雨板\n",
      "冰墩敦人偶服装 \t 灯笼布偶熊猫道具服装吉祥物人穿玩偶宣传传单服活动人偶服装\n",
      "寵物罐頭密封蓋 \t 仁可宠物 猫罐头保鲜盒密封盖防潮可加热猫咪罐头勺喂食勺猫用品\n",
      "15 蒸汽爱美克闸阀 \t 埃美柯 8135闸阀304不锈钢蒸汽用闸阀中型Z15W-16P耐温腐蚀4分6分\n",
      "电动切面机 \t 复兴牌面条机电动家用不锈钢压面机多功能半全自动四种面条DMT-6\n",
      "医用震动排痰机 \t 普门排痰机振动背心式慢阻肺支气管扩张肺气肿医用咳痰祛痰神器\n",
      "草莓盆专用夹 \t 大棚草莓钩盆器新款农具草莓采摘神器摘取自如温室水果铁丝钩子\n",
      "lg洗烘套装 \t LG RC90V9AV2W RC90V9JV2W RH10进口9/10KG热泵双变频干衣烘干机\n",
      "芝士脆 \t 山居小食 芝士小脆棒 香酥小零食 罐装 110g包邮\n",
      "笔记本应用书籍 \t ThinkPad笔记本电脑应用技术精粹\n"
     ]
    }
   ],
   "source": [
    "for idx in range(1, 20):\n",
    "    print(\n",
    "        train_data.loc[idx][\"title\"],\n",
    "        \"\\t\",\n",
    "        corpus_data.loc[qrels.loc[idx].ravel()[0]][\"title\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:26:49.089524Z",
     "iopub.status.busy": "2022-03-16T10:26:49.089107Z",
     "iopub.status.idle": "2022-03-16T10:26:49.875763Z",
     "shell.execute_reply": "2022-03-16T10:26:49.875249Z",
     "shell.execute_reply.started": "2022-03-16T10:26:49.089478Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache /tmp/jieba.cache\n",
      "Loading model cost 0.593 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'美赞臣 亲舒 一段'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jieba\n",
    "\n",
    "\" \".join(jieba.cut(\"美赞臣亲舒一段\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:26:51.911386Z",
     "iopub.status.busy": "2022-03-16T10:26:51.910971Z",
     "iopub.status.idle": "2022-03-16T10:27:26.085249Z",
     "shell.execute_reply": "2022-03-16T10:27:26.084819Z",
     "shell.execute_reply.started": "2022-03-16T10:26:51.911339Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.519 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "Loading model cost 0.523 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "Loading model cost 0.527 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "Loading model cost 0.530 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "def title_cut(x):\n",
    "    return list(jieba.cut(x))\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "corpus_title = Parallel(n_jobs=4)(delayed(title_cut)(title) for title in corpus_data[\"title\"])\n",
    "train_title = Parallel(n_jobs=4)(delayed(title_cut)(title) for title in train_data[\"title\"])\n",
    "dev_title = Parallel(n_jobs=4)(delayed(title_cut)(title) for title in dev_data[\"title\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:28:08.166432Z",
     "iopub.status.busy": "2022-03-16T10:28:08.165880Z",
     "iopub.status.idle": "2022-03-16T10:28:20.530572Z",
     "shell.execute_reply": "2022-03-16T10:28:20.529850Z",
     "shell.execute_reply.started": "2022-03-16T10:28:08.166383Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from gensim.models import Word2Vec\n",
    "from gensim.test.utils import common_texts\n",
    "\n",
    "\n",
    "if os.path.exists(\"word2vec.model\"):\n",
    "    model = Word2Vec.load(\"word2vec.model\")\n",
    "else: \n",
    "    model = Word2Vec(\n",
    "        sentences=list(corpus_title) + list(train_title) + list(dev_title),\n",
    "        vector_size=128,\n",
    "        window=5,\n",
    "        min_count=1,\n",
    "        workers=4,\n",
    "    )\n",
    "    model.save(\"word2vec.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:28:23.219871Z",
     "iopub.status.busy": "2022-03-16T10:28:23.219453Z",
     "iopub.status.idle": "2022-03-16T10:28:23.734125Z",
     "shell.execute_reply": "2022-03-16T10:28:23.733746Z",
     "shell.execute_reply.started": "2022-03-16T10:28:23.219825Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('金羚', 0.8855313658714294),\n",
       " ('海尔', 0.8832445740699768),\n",
       " ('韩电', 0.865885317325592),\n",
       " ('三洋', 0.8579813838005066),\n",
       " ('惠而浦', 0.8569628596305847),\n",
       " ('波轮', 0.8547092080116272),\n",
       " ('吉德', 0.8522803783416748),\n",
       " ('容声', 0.8349811434745789),\n",
       " ('xqb50', 0.8298338055610657),\n",
       " ('荣事达', 0.8265526294708252)]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.most_similar(\"小天鹅\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:28:25.613158Z",
     "iopub.status.busy": "2022-03-16T10:28:25.612741Z",
     "iopub.status.idle": "2022-03-16T10:28:25.621511Z",
     "shell.execute_reply": "2022-03-16T10:28:25.620497Z",
     "shell.execute_reply.started": "2022-03-16T10:28:25.613109Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[' ', '新款', '女', '/', '2021', '-', '加厚', '儿童', '秋冬', '外套']"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.index_to_key[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:28:27.330688Z",
     "iopub.status.busy": "2022-03-16T10:28:27.330140Z",
     "iopub.status.idle": "2022-03-16T10:28:27.335845Z",
     "shell.execute_reply": "2022-03-16T10:28:27.335362Z",
     "shell.execute_reply.started": "2022-03-16T10:28:27.330636Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.key_to_index[\"女\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:28:29.165728Z",
     "iopub.status.busy": "2022-03-16T10:28:29.165210Z",
     "iopub.status.idle": "2022-03-16T10:28:35.649410Z",
     "shell.execute_reply": "2022-03-16T10:28:35.648893Z",
     "shell.execute_reply.started": "2022-03-16T10:28:29.165676Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_w2v_ids = [[model.wv.key_to_index[xx] for xx in x] for x in train_title]\n",
    "corpus_w2v_ids = [[model.wv.key_to_index[xx] for xx in x] for x in corpus_title]\n",
    "dev_w2v_ids = [[model.wv.key_to_index[xx] for xx in x] for x in dev_title]\n",
    "\n",
    "# all_text = \" \".join(train_data[\"title\"])\n",
    "# all_query_word = list(jieba.cut(all_text))\n",
    "# all_query_word = [x for x in all_query_word if len(x) >= 2]\n",
    "# all_query_ids = [model.wv.key_to_index[xx] for xx in all_query_word]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:28:51.429567Z",
     "iopub.status.busy": "2022-03-16T10:28:51.429040Z",
     "iopub.status.idle": "2022-03-16T10:29:01.683285Z",
     "shell.execute_reply": "2022-03-16T10:29:01.682829Z",
     "shell.execute_reply.started": "2022-03-16T10:28:51.429512Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TfidfVectorizer(analyzer=<function <lambda> at 0x7f503c611a60>)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "idf = TfidfVectorizer(analyzer=lambda x: x)\n",
    "idf.fit(train_title + corpus_title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:29:01.684529Z",
     "iopub.status.busy": "2022-03-16T10:29:01.684360Z",
     "iopub.status.idle": "2022-03-16T10:29:01.695529Z",
     "shell.execute_reply": "2022-03-16T10:29:01.695196Z",
     "shell.execute_reply.started": "2022-03-16T10:29:01.684511Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 2.46292242,  8.5771301 ,  7.7050655 , ..., 14.21903717,\n",
       "        14.21903717, 14.21903717]),\n",
       " 640554)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idf.idf_, len(idf.idf_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:29:01.696310Z",
     "iopub.status.busy": "2022-03-16T10:29:01.696140Z",
     "iopub.status.idle": "2022-03-16T10:29:02.736233Z",
     "shell.execute_reply": "2022-03-16T10:29:02.735737Z",
     "shell.execute_reply.started": "2022-03-16T10:29:01.696292Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/.local/lib/python3.9/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.\n",
      "  warnings.warn(msg, category=FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "token = np.array(idf.get_feature_names())\n",
    "drop_token = token[np.where(idf.idf_ < 10)[0]]\n",
    "drop_token = list(set(drop_token))\n",
    "drop_token += ['领券']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:29:02.737630Z",
     "iopub.status.busy": "2022-03-16T10:29:02.737458Z",
     "iopub.status.idle": "2022-03-16T10:29:02.744462Z",
     "shell.execute_reply": "2022-03-16T10:29:02.744064Z",
     "shell.execute_reply.started": "2022-03-16T10:29:02.737611Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "drop_token_ids = [model.wv.key_to_index[x] for x in drop_token]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:29:02.745287Z",
     "iopub.status.busy": "2022-03-16T10:29:02.745093Z",
     "iopub.status.idle": "2022-03-16T10:29:02.826654Z",
     "shell.execute_reply": "2022-03-16T10:29:02.825709Z",
     "shell.execute_reply.started": "2022-03-16T10:29:02.745270Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[11.083542956587955, 13.52588999195716, 10.621724911928661]"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[idf.idf_[idf.vocabulary_[xx]] for xx in train_title[0]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 句子编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:29:55.168273Z",
     "iopub.status.busy": "2022-03-16T10:29:55.167742Z",
     "iopub.status.idle": "2022-03-16T10:29:55.177747Z",
     "shell.execute_reply": "2022-03-16T10:29:55.177089Z",
     "shell.execute_reply.started": "2022-03-16T10:29:55.168220Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def unsuper_w2c_encoding(s, pooling=\"max\"):\n",
    "    feat = []\n",
    "    corpus_query_word = [x for x in s if x not in drop_token_ids]\n",
    "    if len(corpus_query_word) == 0:\n",
    "        return np.zeros(128)\n",
    "    \n",
    "    feat = model.wv[corpus_query_word]\n",
    "\n",
    "    if pooling == \"max\":\n",
    "        return np.array(feat).max(0)\n",
    "    if pooling == \"avg\":\n",
    "        return np.array(feat).mean(0)\n",
    "\n",
    "\n",
    "# def unsuper_w2c_encoding(s, pooling=\"avg\", debug=False):\n",
    "#     feat = []\n",
    "    \n",
    "#     # corpus_query_word = list(set(s) & set(all_query_ids))\n",
    "\n",
    "#     for w in s:\n",
    "        \n",
    "#         if idf.idf_[idf.vocabulary_[w]] > 11:\n",
    "#             if debug:\n",
    "#                 print(w)\n",
    "#             feat.append(model.wv[w])\n",
    "        \n",
    "#     if len(feat) == 0:\n",
    "#         return np.zeros(128)\n",
    "\n",
    "\n",
    "#     if pooling == \"max\":\n",
    "#         return np.array(feat).max(0)\n",
    "#     if pooling == \"avg\":\n",
    "#         return np.array(feat).mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:29:56.399446Z",
     "iopub.status.busy": "2022-03-16T10:29:56.398913Z",
     "iopub.status.idle": "2022-03-16T10:29:56.407825Z",
     "shell.execute_reply": "2022-03-16T10:29:56.407213Z",
     "shell.execute_reply.started": "2022-03-16T10:29:56.399388Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.13844971,  0.02329931,  0.00218298,  0.30193067,  0.24012831,\n",
       "       -0.03772005,  0.06984258, -0.00104502,  0.2525359 ,  0.02744666,\n",
       "        0.43275502,  0.03205966,  0.26791453, -0.04520423,  0.09708406,\n",
       "        0.03066566, -0.04605221, -0.02081237, -0.03759015,  0.01533887,\n",
       "        0.17473975,  0.54764605, -0.03636033, -0.0470927 ,  0.06880933,\n",
       "        0.30155945, -0.03811692,  0.4881905 , -0.02433527, -0.03787201,\n",
       "       -0.05090072,  0.1707239 , -0.04582863,  0.26897642,  0.18457308,\n",
       "        0.2454911 ,  0.01931062,  0.482265  ,  0.23519492, -0.01226895,\n",
       "        0.28913426,  0.46038544,  0.27555916,  0.05334978,  0.0333859 ,\n",
       "        0.17673498, -0.04632355, -0.01765269,  0.40820375,  0.14650379,\n",
       "        0.38070896,  0.41259366,  0.06576917,  0.8626872 ,  0.23650156,\n",
       "        0.0819394 ,  0.49214628, -0.02028468, -0.03448511, -0.00941355,\n",
       "       -0.02544864, -0.02463404,  0.1276301 ,  0.1326973 ,  0.30747333,\n",
       "       -0.01177661,  0.27607018,  0.42989215, -0.02167742, -0.07949985,\n",
       "        0.10563007, -0.04022123, -0.01270665, -0.05555258,  0.09551235,\n",
       "       -0.04936175,  0.00177826, -0.03438317, -0.05237317,  0.1167613 ,\n",
       "       -0.00116845,  0.13330679,  0.2522552 ,  0.60099244,  0.25534162,\n",
       "        0.11306063,  0.27569652, -0.07508233,  0.17610143,  0.353241  ,\n",
       "        0.21001127, -0.03868374,  0.25076875,  0.00679028,  0.21244186,\n",
       "       -0.02621726, -0.03703961, -0.08610518, -0.0076225 , -0.01768565,\n",
       "       -0.05622293, -0.03791003,  0.14516251,  0.4334494 ,  0.03873612,\n",
       "        0.01481874, -0.01375509,  0.04263891,  0.21438226,  0.06256728,\n",
       "        0.02565995, -0.02784491,  0.10379297,  0.00886007, -0.03401844,\n",
       "        0.58619744,  0.26690742,  0.21823981,  0.21963592,  0.0816617 ,\n",
       "       -0.02957517, -0.04989241, -0.02675758,  0.1876288 , -0.0081165 ,\n",
       "       -0.05108411, -0.03255492,  0.02651718], dtype=float32)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unsuper_w2c_encoding(train_w2v_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:29:57.359402Z",
     "iopub.status.busy": "2022-03-16T10:29:57.358952Z",
     "iopub.status.idle": "2022-03-16T10:29:57.368148Z",
     "shell.execute_reply": "2022-03-16T10:29:57.367230Z",
     "shell.execute_reply.started": "2022-03-16T10:29:57.359356Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "_ = unsuper_w2c_encoding(corpus_w2v_ids[679139-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:30:02.167603Z",
     "iopub.status.busy": "2022-03-16T10:30:02.167113Z",
     "iopub.status.idle": "2022-03-16T10:30:03.905367Z",
     "shell.execute_reply": "2022-03-16T10:30:03.905001Z",
     "shell.execute_reply.started": "2022-03-16T10:30:02.167552Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_5053/2441499990.py:4: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  unsuper_w2c_encoding(s) for s in tqdm_notebook(corpus_w2v_ids[:1000] + [corpus_w2v_ids[x] for x in qrels['doc'].values[:100] - 1])\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69065c5a389f433bba9474c64601296e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_5053/2441499990.py:9: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  unsuper_w2c_encoding(s) for s in tqdm_notebook(train_w2v_ids[:100])\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "67c06c511ca04a61afaeedd9b617ab4d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_5053/2441499990.py:14: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  unsuper_w2c_encoding(s) for s in tqdm_notebook(dev_w2v_ids[:100])\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "30180de292e8400e9d5870b19f1671d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import tqdm_notebook\n",
    "# [corpus_w2v_ids[x] for x in qrels['doc'].values[:100] - 1]\n",
    "\n",
    "corpus_mean_feat = [\n",
    "    unsuper_w2c_encoding(s) for s in tqdm_notebook(corpus_w2v_ids[:1000])\n",
    "]\n",
    "corpus_mean_feat = np.vstack(corpus_mean_feat)\n",
    "\n",
    "train_mean_feat = [\n",
    "    unsuper_w2c_encoding(s) for s in tqdm_notebook(train_w2v_ids[:100])\n",
    "]\n",
    "train_mean_feat = np.vstack(train_mean_feat)\n",
    "\n",
    "dev_mean_feat = [\n",
    "    unsuper_w2c_encoding(s) for s in tqdm_notebook(dev_w2v_ids[:100])\n",
    "]\n",
    "dev_mean_feat = np.vstack(dev_mean_feat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 初步检索"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:30:07.168338Z",
     "iopub.status.busy": "2022-03-16T10:30:07.167788Z",
     "iopub.status.idle": "2022-03-16T10:30:07.176152Z",
     "shell.execute_reply": "2022-03-16T10:30:07.175637Z",
     "shell.execute_reply.started": "2022-03-16T10:30:07.168289Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "corpus_mean_feat = normalize(corpus_mean_feat)\n",
    "train_mean_feat = normalize(train_mean_feat)\n",
    "dev_mean_feat = normalize(dev_mean_feat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:30:42.364737Z",
     "iopub.status.busy": "2022-03-16T10:30:42.364310Z",
     "iopub.status.idle": "2022-03-16T10:30:42.399839Z",
     "shell.execute_reply": "2022-03-16T10:30:42.399106Z",
     "shell.execute_reply.started": "2022-03-16T10:30:42.364690Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_5053/560909120.py:2: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for idx in tqdm_notebook(range(1, 100)):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95617ec96ae140228a8a66e21d03bc60",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/99 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mrr = []\n",
    "for idx in tqdm_notebook(range(1, 100)):\n",
    "    dis = np.dot(train_mean_feat[idx - 1], corpus_mean_feat.T)\n",
    "    ids = np.argsort(dis)[::-1]\n",
    "    \n",
    "    # print(train_title[idx-1], corpus_data.loc[qrels.loc[idx].ravel()[0]][\"title\"],  dis[qrels.loc[idx].ravel()-1])\n",
    "    # print(corpus_title[ids[0]])\n",
    "    # mrr.append(1/(np.where(ids == qrels.loc[idx].ravel()[0] - 1)[0][0] + 1))\n",
    "    # break\n",
    "    # print('')\n",
    "    # mrr.append(ids[0]==idx+999)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-16T10:30:43.165139Z",
     "iopub.status.busy": "2022-03-16T10:30:43.164600Z",
     "iopub.status.idle": "2022-03-16T10:30:43.216282Z",
     "shell.execute_reply": "2022-03-16T10:30:43.215020Z",
     "shell.execute_reply.started": "2022-03-16T10:30:43.165090Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.18181818181818182"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(mrr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-03-15T17:54:17.200972Z",
     "iopub.status.busy": "2022-03-15T17:54:17.200544Z",
     "iopub.status.idle": "2022-03-15T17:55:43.230239Z",
     "shell.execute_reply": "2022-03-15T17:55:43.229069Z",
     "shell.execute_reply.started": "2022-03-15T17:54:17.200925Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "with open('query_embedding', 'w') as up :\n",
    "    for id, feat in zip(dev_data.index, dev_mean_feat):\n",
    "        up.write('{0}\\t{1}\\n'.format(id, ','.join([str(x)[:6] for x in feat])))\n",
    "        \n",
    "with open('doc_embedding', 'w') as up :\n",
    "    for id, feat in zip(corpus_data.index, corpus_mean_feat):\n",
    "        up.write('{0}\\t{1}\\n'.format(id, ','.join([str(x)[:6] for x in feat])))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
